In [1]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

Prep Data

In [2]:
# Load the source -> target flow table and keep only the rows that are both
# from 2020 onward and part of the 'rcp45cooler_ssp3' scenario.
# NOTE(review): the path below is absolute and machine-specific; the notebook
# only runs where this exact file exists.
flows_all = pd.read_csv('/Users/mong275/Downloads/elec_source_target_to_test_animation.csv')
wanted_rows = (flows_all['year'] >= 2020) & (flows_all['scenario'] == 'rcp45cooler_ssp3')
df = flows_all[wanted_rows]
df
Out[2]:
scenario source target year value
5 rcp45cooler_ssp3 CSP_resource CSP 2020 0.021108
6 rcp45cooler_ssp3 CSP_resource CSP 2025 0.021111
7 rcp45cooler_ssp3 CSP_resource CSP 2030 0.021121
8 rcp45cooler_ssp3 CSP_resource CSP 2035 0.021197
9 rcp45cooler_ssp3 CSP_resource CSP 2040 0.021341
... ... ... ... ... ...
6856 rcp45cooler_ssp3 elect_td_trn trn_pass_road_LDV_4W 2080 1.020281
6857 rcp45cooler_ssp3 elect_td_trn trn_pass_road_LDV_4W 2085 1.116455
6858 rcp45cooler_ssp3 elect_td_trn trn_pass_road_LDV_4W 2090 1.226127
6859 rcp45cooler_ssp3 elect_td_trn trn_pass_road_LDV_4W 2095 1.386638
6860 rcp45cooler_ssp3 elect_td_trn trn_pass_road_LDV_4W 2100 1.549035

1132 rows × 5 columns

Prep Unique Node Labels

In [3]:
# Gather every distinct node name: stack the 'source' column on top of the
# 'target' column, then drop repeats. The result is an array, not a list.
all_node_names = pd.concat([df['source'], df['target']])
unique_categories = all_node_names.unique()

print(unique_categories)
['CSP_resource' 'nuclearFuelGenIII' 'nuclearFuelGenII' 'PV_resource'
 'regional biomass' 'regional coal' 'wholesale gas'
 'onshore wind resource' 'offshore wind resource' 'biomass (IGCC CCS)'
 'biomass (IGCC)' 'biomass (conv CCS)' 'biomass (conv)' 'coal (IGCC CCS)'
 'coal (IGCC)' 'coal (conv pul CCS)' 'coal (conv pul)' 'gas (CC CCS)'
 'gas (CC)' 'gas (CT)' 'gas (steam)' 'Gen_III' 'Gen_II_LWR' 'CSP'
 'CSP (dry_hybrid)' 'PV' 'wind_offshore' 'biomass' 'coal' 'gas'
 'geothermal' 'hydro' 'nuclear' 'solar' 'wind' 'electricity'
 'elect_td_ind' 'elect_td_bld' 'elect_td_trn' 'biomass liquids' 'cement'
 'comm cooking' 'comm hot water' 'comm lighting' 'comm non-building'
 'comm office' 'comm other' 'comm refrigeration' 'comm ventilation'
 'industrial energy use' 'oil refining' 'resid clothes dryers'
 'resid clothes washers' 'resid computers' 'resid cooking'
 'resid dishwashers' 'resid freezers' 'resid furnace fans'
 'resid hot water' 'resid lighting' 'resid other' 'resid refrigerators'
 'resid televisions' 'trn_pass' 'trn_pass_road_LDV' 'trn_pass_road_LDV_4W']

Prep Dictionary Mapping Each Node to Its Index

In [4]:
# Pair each node name with its position in unique_categories (0, 1, 2, ...),
# giving a name -> integer index lookup table.
category_index = dict(zip(unique_categories, range(len(unique_categories))))

print(category_index)
{'CSP_resource': 0, 'nuclearFuelGenIII': 1, 'nuclearFuelGenII': 2, 'PV_resource': 3, 'regional biomass': 4, 'regional coal': 5, 'wholesale gas': 6, 'onshore wind resource': 7, 'offshore wind resource': 8, 'biomass (IGCC CCS)': 9, 'biomass (IGCC)': 10, 'biomass (conv CCS)': 11, 'biomass (conv)': 12, 'coal (IGCC CCS)': 13, 'coal (IGCC)': 14, 'coal (conv pul CCS)': 15, 'coal (conv pul)': 16, 'gas (CC CCS)': 17, 'gas (CC)': 18, 'gas (CT)': 19, 'gas (steam)': 20, 'Gen_III': 21, 'Gen_II_LWR': 22, 'CSP': 23, 'CSP (dry_hybrid)': 24, 'PV': 25, 'wind_offshore': 26, 'biomass': 27, 'coal': 28, 'gas': 29, 'geothermal': 30, 'hydro': 31, 'nuclear': 32, 'solar': 33, 'wind': 34, 'electricity': 35, 'elect_td_ind': 36, 'elect_td_bld': 37, 'elect_td_trn': 38, 'biomass liquids': 39, 'cement': 40, 'comm cooking': 41, 'comm hot water': 42, 'comm lighting': 43, 'comm non-building': 44, 'comm office': 45, 'comm other': 46, 'comm refrigeration': 47, 'comm ventilation': 48, 'industrial energy use': 49, 'oil refining': 50, 'resid clothes dryers': 51, 'resid clothes washers': 52, 'resid computers': 53, 'resid cooking': 54, 'resid dishwashers': 55, 'resid freezers': 56, 'resid furnace fans': 57, 'resid hot water': 58, 'resid lighting': 59, 'resid other': 60, 'resid refrigerators': 61, 'resid televisions': 62, 'trn_pass': 63, 'trn_pass_road_LDV': 64, 'trn_pass_road_LDV_4W': 65}

Build an Animation Frame for Each Year

In [5]:
# Build one animation frame per year; each frame carries a complete Sankey trace.
frames = []

# Every frame uses the same, complete list of node labels, so build it once
# outside the loop instead of on every iteration. Defining it here also means
# the name exists even when the loop body never runs (no years in df).
node_labels = list(unique_categories)

# Create Sankey diagrams for each year
for year in df['year'].unique():

    # Filter dataframe to the current year
    year_data = df[df['year'] == year]

    # Create the links (flows between source and target categories).
    # Series.map looks each node name up in category_index, giving the
    # node's integer position in node_labels.
    link_sources = year_data['source'].map(category_index).tolist()
    link_targets = year_data['target'].map(category_index).tolist()
    link_values = year_data['value'].tolist()

    # Create the Sankey diagram for the current year. The frame is named
    # after the year so that it can be addressed by that string later.
    frame = go.Frame(
        data=[go.Sankey(
            node=dict(
                pad=15,
                thickness=20,
                line=dict(color="black", width=0.5),
                label=node_labels
            ),
            link=dict(
                source=link_sources,
                target=link_targets,
                value=link_values
            )
        )],
        name=str(year)
    )
    frames.append(frame)
In [6]:
# Show the source node indices left over from the last iteration of the frame loop
print(link_sources)
[0, 1, 3, 4, 4, 4, 4, 5, 5, 6, 6, 6, 7, 8, 9, 10, 11, 12, 13, 15, 17, 18, 19, 21, 23, 24, 25, 26, 27, 28, 29, 31, 32, 33, 34, 35, 35, 35, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 38, 38, 38]
In [7]:
# Show the target node indices left over from the last iteration of the frame loop
print(link_targets)
[23, 21, 25, 9, 10, 11, 12, 13, 15, 17, 18, 19, 34, 26, 27, 27, 27, 27, 28, 28, 29, 29, 29, 32, 33, 33, 33, 34, 35, 35, 35, 35, 35, 35, 35, 37, 36, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]
In [12]:
# Show the link values left over from the last iteration of the frame loop
print(link_values)
[0.000512338470933, 4.671176361175375, 8.6141058060614, 1.170197810741578, 4.49691581e-07, 3.04859349898987, 3.0536102e-07, 0.583917610509, 0.389283820643, 1.1661039090327, 0.31974179126453, 0.001448577741377, 13.6815111577, 5.824963979, 0.548203142642125, 2.06801278e-07, 1.217616161966965, 1.1430609000000001e-07, 0.3254783039377, 0.208301605116, 0.7978911239175001, 0.21057627312985, 0.000675885736483, 1.555501603665691, 0.000147350976199, 6.507582059e-05, 5.0157824334433, 5.67801168, 1.765819625716458, 0.5337799090537, 1.009143282783833, 0.96966230918, 1.555501603665691, 5.015994860240089, 15.8038972935, 7.50967311, 10.316003216, 1.63581674, 0.0127287424367, 0.0268907413936, 0.11702812448, 0.350499183431, 0.53085950781, 0.69687496, 0.304222139, 1.28873805, 0.30611342327, 0.46277063703, 10.270105075718, 0.00628066980985, 0.17939548172, 0.012072139662, 0.0308075712, 0.14947238659, 0.07072873883, 0.06175670141, 0.0542639937, 1.113498087262, 0.08151715107, 1.19606613, 0.28389289079, 0.219095358, 0.06356972111, 0.0232124428, 1.549034906]

Create Animation

In [8]:
# Create the initial Sankey diagram (for the earliest year in the data)
initial_year = df['year'].min()
initial_year_data = df[df['year'] == initial_year]
initial_sources = initial_year_data['source'].map(category_index).tolist()
initial_targets = initial_year_data['target'].map(category_index).tolist()
initial_values = initial_year_data['value'].tolist()

# Build the node labels here from unique_categories rather than relying on a
# variable that happens to be left over from the frame-building loop.
node_labels = list(unique_categories)

# One slider step per year. The first element of args is the list of frame
# names to show; it is str(year), the same string the frames are named with.
slider_steps = [
    dict(
        label=str(year),
        method="animate",
        args=[[str(year)], dict(mode="immediate", frame=dict(duration=0, redraw=True))]
    )
    for year in df['year'].unique()
]

# Create the figure with the first year's data, a Play button, a year slider,
# and the per-year frames.
fig = go.Figure(
    data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=node_labels
        ),
        link=dict(
            source=initial_sources,
            target=initial_targets,
            value=initial_values
        )
    )],
    layout=dict(
        title="Sankey Diagram by Year",
        updatemenus=[dict(
            type="buttons",
            showactive=False,
            buttons=[dict(
                label="Play",
                method="animate",
                # None as the frame selector plays the frames in sequence,
                # starting from the current one (fromcurrent=True).
                args=[None, dict(frame=dict(duration=1000, redraw=True), fromcurrent=True)]
            )]
        )],
        sliders=[dict(
            yanchor="top", xanchor="left",
            currentvalue=dict(font=dict(size=20), visible=True, xanchor="right"),
            transition=dict(duration=300),
            steps=slider_steps
        )]
    ),
    frames=frames
)

# Fix the canvas size and margins instead of letting the figure auto-size
fig.update_layout(
    autosize=False,
    width=1200,
    height=800,
    margin=dict(l=50, r=50, b=100, t=100, pad=4)
)

fig.show()